{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calibrate MegaDetector v4.1\n",
    "\n",
    "This notebook plots a calibration curve for each detection category, fits an isotonic regression that calibrates the detector's confidence scores, and saves the parameters of the fitted calibration function. See the last cell for how to load the saved calibration function parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import sklearn.isotonic\n",
    "from tqdm import tqdm\n",
    "\n",
    "from detection.detector_eval import detector_eval  # requires TF ODAPI\n",
    "from visualization import plot_utils  # provides plot_calibration_curve (import path assumed - confirm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# detections JSON file, written in the Batch API output format\n",
    "DETECTIONS_JSON_PATH = 'mdv4_1_detections_on_test.json'\n",
    "\n",
    "# JSON file holding the list of results returned by a MegaDB query\n",
    "LABELS_JSON_PATH = 'mdv4_1_labels_on_test.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _read_json(path):\n",
    "    \"\"\"Return the parsed contents of the JSON file at `path`.\"\"\"\n",
    "    with open(path, 'r') as fp:\n",
    "        return json.load(fp)\n",
    "\n",
    "\n",
    "# detections first, then the labels, same order as the path constants\n",
    "detections_js = _read_json(DETECTIONS_JSON_PATH)\n",
    "labels_js = _read_json(LABELS_JSON_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# index the label entries by image filename: '<download_id>.jpg'\n",
    "gt_db_dict = {}\n",
    "for label_entry in labels_js:\n",
    "    gt_db_dict[label_entry['download_id'] + '.jpg'] = label_entry\n",
    "\n",
    "# index the detection entries by the basename of each image's 'file' path\n",
    "detection_res = {}\n",
    "for det_entry in detections_js['images']:\n",
    "    detection_res[os.path.basename(det_entry['file'])] = det_entry\n",
    "\n",
    "# JSON object keys are strings, so convert the category IDs to ints\n",
    "label_id_to_name = {}\n",
    "for raw_id, cat_name in detections_js['detection_categories'].items():\n",
    "    label_id_to_name[int(raw_id)] = cat_name\n",
    "\n",
    "# reverse lookup: category name -> integer category ID\n",
    "label_map_name_to_id = dict(\n",
    "    (cat_name, cat_id) for cat_id, cat_name in label_id_to_name.items())\n",
    "display(label_map_name_to_id)\n",
    "\n",
    "# every image that has detections must also have a label entry\n",
    "assert set(detection_res).issubset(gt_db_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the helper returns a pair: (ground truths, detections); the names suggest\n",
    "# both are organized per image - see detector_eval for the exact layout\n",
    "gts_and_detections = detector_eval.get_per_image_gts_and_detections(\n",
    "    gt_db_dict=gt_db_dict,\n",
    "    detection_res=detection_res,\n",
    "    label_map_name_to_id=label_map_name_to_id,\n",
    ")\n",
    "per_image_gts, per_image_detections = gts_and_detections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of classes = number of categories listed in the detections file\n",
    "num_gt_classes = len(detections_js['detection_categories'])\n",
    "\n",
    "# the result is indexed by category ID in the cells below\n",
    "per_cat_metrics = detector_eval.compute_precision_recall_bbox(\n",
    "    per_image_detections=per_image_detections,\n",
    "    per_image_gts=per_image_gts,\n",
    "    num_gt_classes=num_gt_classes,\n",
    "    matching_iou_threshold=0.5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calibration curves of the raw detector scores, one panel per detection\n",
    "# category. The panel count and the category IDs come from label_id_to_name\n",
    "# instead of assuming exactly three categories numbered 1, 2, 3.\n",
    "cat_ids = sorted(label_id_to_name)\n",
    "# squeeze=False keeps axs two-dimensional even when there is a single category\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=1, ncols=len(cat_ids), squeeze=False,\n",
    "    figsize=(5 * len(cat_ids), 5), facecolor='white', tight_layout=True)\n",
    "for panel_idx, (cat_id, ax) in enumerate(zip(cat_ids, axs[0])):\n",
    "    plot_utils.plot_calibration_curve(\n",
    "        true_scores=per_cat_metrics[cat_id]['tp_fp'],\n",
    "        pred_scores=per_cat_metrics[cat_id]['scores'],\n",
    "        num_bins=15, ax=ax)\n",
    "\n",
    "    # append the category name below whatever title the plot helper set\n",
    "    cat = label_id_to_name[cat_id]\n",
    "    ax.set_title(ax.get_title() + '\\n' + cat)\n",
    "    if panel_idx == 0:\n",
    "        # the figure legend is created right after the first panel is drawn,\n",
    "        # i.e. before the remaining panels contain any artists\n",
    "        fig.legend(loc='upper left', bbox_to_anchor=(0.05, 0.85))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pool the 'scores' and 'tp_fp' arrays of every category: a single calibrator\n",
    "# is fit on all of them together\n",
    "all_pred_scores = np.concatenate([\n",
    "    per_cat_metrics[cat_id]['scores'] for cat_id in label_id_to_name\n",
    "])\n",
    "all_true_scores = np.concatenate([\n",
    "    per_cat_metrics[cat_id]['tp_fp'] for cat_id in label_id_to_name\n",
    "])\n",
    "\n",
    "# non-decreasing fit with outputs bounded to [0, 1]; out_of_bounds='raise' makes\n",
    "# transform() fail on scores outside the range seen during fit\n",
    "calibrator = sklearn.isotonic.IsotonicRegression(y_min=0, y_max=1, increasing=True, out_of_bounds='raise')\n",
    "calibrator.fit(all_pred_scores, all_true_scores)\n",
    "# calibrator.f_ is a scipy.interpolate.interp1d object; its x and y arrays are\n",
    "# what the last cell uses to rebuild the calibration function\n",
    "np.savez_compressed(\n",
    "    'mdv4_1_isotonic_calibration.npz',\n",
    "    x=calibrator.f_.x,\n",
    "    y=calibrator.f_.y)\n",
    "\n",
    "# One panel per detection category, comparing uncalibrated and calibrated\n",
    "# scores. The panel count and the category IDs come from label_id_to_name\n",
    "# instead of assuming exactly three categories numbered 1, 2, 3.\n",
    "cat_ids = sorted(label_id_to_name)\n",
    "# squeeze=False keeps axs two-dimensional even when there is a single category\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=1, ncols=len(cat_ids), squeeze=False,\n",
    "    figsize=(5 * len(cat_ids), 5), facecolor='white', tight_layout=True)\n",
    "for panel_idx, (cat_id, ax) in enumerate(zip(cat_ids, axs[0])):\n",
    "    true_scores = per_cat_metrics[cat_id]['tp_fp']\n",
    "    pred_scores = per_cat_metrics[cat_id]['scores']\n",
    "    plot_utils.plot_calibration_curve(\n",
    "        true_scores=true_scores,\n",
    "        pred_scores=pred_scores,\n",
    "        num_bins=15, name='uncalibrated outputs', ax=ax)\n",
    "    # second curve on the same axes, with plot_hist and plot_perf switched off\n",
    "    plot_utils.plot_calibration_curve(\n",
    "        true_scores=true_scores,\n",
    "        pred_scores=calibrator.transform(pred_scores),\n",
    "        num_bins=15, name='calibrated outputs', ax=ax,\n",
    "        plot_hist=False, plot_perf=False)\n",
    "\n",
    "    # append the category name below whatever title the plot helper set\n",
    "    cat = label_id_to_name[cat_id]\n",
    "    ax.set_title(ax.get_title() + '\\n' + cat)\n",
    "    if panel_idx == 0:\n",
    "        # the figure legend is created right after the first panel is drawn,\n",
    "        # i.e. before the remaining panels contain any artists\n",
    "        fig.legend(loc='upper left', bbox_to_anchor=(0.05, 0.85))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# as a sanity check, and to demonstrate how to load the isotonic calibration\n",
    "import scipy.interpolate\n",
    "\n",
    "# rebuild the calibration function from the saved x / y arrays\n",
    "with np.load('mdv4_1_isotonic_calibration.npz') as npz:\n",
    "    calibration_fn = scipy.interpolate.interp1d(x=npz['x'], y=npz['y'], kind='linear')\n",
    "\n",
    "# the reloaded function must reproduce the fitted calibrator; compare with a\n",
    "# tight tolerance rather than exact float equality, which also yields a\n",
    "# mismatch report if the check fails\n",
    "np.testing.assert_allclose(\n",
    "    calibration_fn(all_pred_scores),\n",
    "    calibrator.transform(all_pred_scores),\n",
    "    rtol=1e-9, atol=1e-12)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:cameratraps-classifier] *",
   "language": "python",
   "name": "conda-env-cameratraps-classifier-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
